from ActualCausal.Train.EM.mask_utils import mask_assign, mask_most_common, mask_likelihood_probabilities, generate_masks
from ActualCausal.Train.Active.train_given_active import train_given_active
from ActualCausal.Train.Inter.train_trace import train_binaries
from ActualCausal.Train.Active.train_mask_active import train_masked_active
import numpy as np

#TODO: not updated yet, needs to return results
def train_EM(args, params, model, buffer, log_batch=None, wrap_function=None, additional=None):
    """Run the alternating (EM-style) mask-assignment / model-training loop.

    For each of ``args.train.num_iters`` iterations this function:

    1. queries ``model.check_passive_mask(buffer.target)`` for the passive mask,
    2. assigns hard and soft masks with ``mask_assign``,
    3. selects candidate masks with ``mask_most_common``,
    4. computes one weight entry per candidate mask with
       ``mask_likelihood_probabilities``,
    5. trains on every candidate mask with ``train_given_active`` and collects
       each result's ``loss_log_likelihoods``,
    6. builds interaction masks with ``generate_masks`` and passes them to
       ``train_binaries``,
    7. finishes the iteration with ``train_masked_active``.

    Args:
        args: Configuration object; ``args.train.num_iters`` and
            ``args.em.model_mask_weights`` are read here, and the whole object
            is forwarded to the helper routines.
        params: Training parameter object forwarded to the helpers.
            ``params.weights`` is overwritten for every candidate mask and is
            left holding the weights of the last mask that was trained.
        model: Model being trained; must provide ``check_passive_mask``.
        buffer: Data buffer; ``buffer.target`` is read here.
        log_batch: Optional list forwarded to the training helpers. Defaults
            to a fresh empty list for each call.
        wrap_function: Optional callable forwarded unchanged to the helpers.
        additional: Optional list forwarded to the training helpers. Defaults
            to a fresh empty list for each call.

    Returns:
        None. (See the module-level TODO: results are not returned yet.)
    """
    # Use None sentinels instead of mutable defaults so that a list mutated by
    # a downstream helper is never shared between unrelated calls.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    for _ in range(args.train.num_iters):
        passive_mask = model.check_passive_mask(buffer.target)
        # gets the interaction output for each value
        hard_masks, soft_masks = mask_assign(args, params, model, buffer, wrap_function=wrap_function)
        # gets the k most common masks
        # TODO: get the k most common mask relational functions, allow subset masks
        k_masks = mask_most_common(hard_masks, model, args, params)

        # NOTE(review): this list is always empty when handed to
        # mask_likelihood_probabilities; if the previous iteration's
        # likelihoods were meant to be passed here, that is not what happens.
        # Confirm against mask_likelihood_probabilities before changing.
        model_performances = []
        # gets probability of data point under mask assignment and soft mask
        # TODO: don't penalize subset masks
        weights = mask_likelihood_probabilities(
            k_masks,
            soft_masks,
            passive_mask,
            args.em.model_mask_weights,
            model_likelihoods=model_performances,
        )

        # TODO: multithreading on the forward model training
        for model_index, mask in enumerate(k_masks):
            # The helpers read the per-mask weights from params, so they are
            # swapped in before each training call.
            params.weights = weights[model_index]
            result = train_given_active(
                args, params, model, buffer, mask,
                log_batch=log_batch, wrap_function=wrap_function, additional=additional,
            )
            model_performances.append(result.loss_log_likelihoods)

        # Join the per-mask likelihoods along the last axis.
        # NOTE(review): np.concatenate raises ValueError if k_masks is empty.
        model_performances = np.concatenate(model_performances, axis=-1)
        interaction_masks = generate_masks(
            args, k_masks, np.expand_dims(passive_mask[0], axis=0), model_performances
        )
        train_binaries(
            args, params, model, buffer, interaction_masks,
            wrap_function=wrap_function, additional=additional,
        )
        train_masked_active(
            args, params, model, buffer,
            log_batch=log_batch, wrap_function=wrap_function, additional=additional,
        )

